#! /usr/bin/env python3

import re
from copy import deepcopy
import logging
import operator
import numpy as np

import urlearning.BayesianNetwork as BayesianNetwork

logger = logging.getLogger(__name__)

# This module implements a simplified version of the score cache for
# URLearning (the ScoreCache class below) together with the small parsing and
# writing helpers it relies on.

def parse(line, start, delimiters):
    """ Split the tail of a line into its non-empty tokens.

        Args:
            line (str): the text to tokenize

            start (int): index of the first character of line to consider

            delimiters (str): regular expression matching the separators

        Returns:
            list of str: the pieces between separators, with empty pieces
                dropped
    """
    tail = line[start:]
    # consecutive separators produce empty pieces; keep only the real tokens
    return [piece for piece in re.split(delimiters, tail) if piece]

def parseMetaInformation(line):
    """ Tokenize a meta line on "=" after skipping its first four
        characters (the length of the "meta" keyword).

        Args:
            line (str): the meta line

        Returns:
            list of str: the non-empty pieces around "=", not stripped
    """
    return parse(line, start=4, delimiters="=")

def parseVariableValues(line):
    """ Tokenize a list of variable values separated by commas and/or
        spaces.

        Args:
            line (str): the text holding the values

        Returns:
            list of str: the non-empty value tokens
    """
    return parse(line, start=0, delimiters=",| ")

def get_parents_list(parent_mask, num_variables):
    """ This function checks each bit in the parent mask to see if the
        respective parent is present in the set.

        The mask is converted to a plain Python integer once, so no NumPy
        scalar is created per bit and testing bit positions of 64 or more
        cannot overflow.

        Args:
            parent_mask (np.uint64 or int): a bit mask indicating the parent
                set; bit p set means variable p is a parent

            num_variables (int): the number of variables in the entire
                dataset; only bits 0 .. num_variables-1 are inspected

        Returns:
            list of ints: the parents indicated by the parent set,
                as a list of integers in increasing order
    """
    mask = int(parent_mask)
    return [p for p in range(num_variables) if (mask >> p) & 1]

def _write_jkl_score(score, parents, num_variables, f):
    """ Internal helper: emit one local score to the open file f as a
        single jkl line of the form

            <score> <number of parents> <parent indices separated by spaces>

        A positive score is negated first, so the written value is never
        positive. Not intended for external use.
    """
    # make sure the written score is not positive
    signed_score = -score if score > 0 else score

    # decode the bit mask into the list of parent indices
    members = get_parents_list(parents, num_variables)
    members_str = ' '.join(map(str, members))

    # score, parent count and parents, each separated by one space
    f.write("{} {} {}\n".format(signed_score, len(members), members_str))

class ScoreCache:
    """ A simplified score cache for URLearning.

        For every variable the cache holds two parallel NumPy arrays: the
        parent sets, encoded as np.uint64 bit masks (bit i set means variable
        i is a parent), and the matching local scores. addCache stores the
        entries of a variable in ascending score order; the readers store
        the absolute value of every score they parse.
    """

    # constructor
    def __init__(self, filename="", file_type='pss'):
        """ Create an empty cache and optionally load it from a file.

            Args:
                filename (str): file to read; nothing is read when empty

                file_type (str): 'jkl' selects read_jkl, anything else
                    selects read (the pss format)
        """
        # number of variables in the network
        self.variableCount = 0
        # one np.uint64 array of parent-set masks per variable
        self.parents = []
        # one float array of scores per variable, parallel to self.parents
        self.scores = []
        # cache-level "META key=value" entries
        self.metaInformation = {}
        self.network = BayesianNetwork.BayesianNetwork()

        if filename:
            if file_type == 'jkl':
                self.read_jkl(filename)
            else:
                self.read(filename)

    def _check_variable_limit(self):
        """ Raise OverflowError when the variables do not fit in the 64 bits
            of a parent-set mask.
        """
        if self.variableCount > 64:
            msg = ("Error while reading score cache.  Too many variables: '{}' (maximum is 64)".format(
                self.variableCount))
            raise OverflowError(msg)

    def _parent_mask(self, names):
        """ Return the np.uint64 bit mask with one bit set for each variable
            name in names, using the indices reported by the network.
        """
        mask = 0
        for name in names:
            mask |= 1 << int(self.network.getVariableIndex(name))
        return np.uint64(mask)

    def getVariableCount(self):
        """ Return the number of variables in the cache. """
        return self.variableCount

    def __len__(self):
        """ Return the number of variables in the cache. """
        return self.variableCount

    def getScoreCount(self):
        """ Return the total number of stored scores over all variables. """
        return sum(len(c) for c in self.parents)

    def getMaxScoreCount(self):
        """ This method returns the largest number of stored parent sets
            (scores) for any single variable in the cache, or 0 when no
            per-variable cache has been added yet.
        """
        return max((len(c) for c in self.parents), default=0)

    def putMetaInformation(self, key, value):
        """ Store a cache-level meta entry. """
        self.metaInformation[key] = value

    def getMetaInformation(self, key):
        """ Return the cache-level meta entry for key, or None if absent. """
        return self.metaInformation.get(key)

    def addCache(self, cache):
        """ Append the scores of the next variable.

            Args:
                cache (dict): maps a parent-set bit mask to its score

            The entries are stored sorted by score, ties broken by the
            parent mask, so the smallest score comes first.
        """
        ordered = sorted(cache.items(), key=operator.itemgetter(1, 0))

        self.parents.append(
            np.array([mask for mask, _ in ordered], dtype=np.uint64))
        self.scores.append(
            np.array([score for _, score in ordered], dtype=np.float64))

    def read(self, filename):
        """ Load the cache from a pss file.

            Every line is stripped and lower-cased before it is parsed.
            Empty lines and lines starting with "#" are ignored. The file
            starts with "META key=value" lines, followed by one section per
            variable: a "VAR name" line, "META" lines ("arity", "values" or
            free-form) and score lines of the form
            "score parent_name parent_name ...".

            The file is read twice, because a score line may name a parent
            that is declared later in the file.

            Raises:
                SyntaxError: on a malformed header or variable META line, or
                    when the file declares no variable

                OverflowError: when the file declares more than 64 variables
        """
        # first pass: cache meta information and the variable declarations
        with open(filename) as f:

            for line in f:
                # skip empty lines and comments
                line = line.strip().lower()
                if not line or line.startswith("#"):
                    continue

                # check if we reached the first variable; keywords are only
                # recognised at the start of the line, so a META value or a
                # name that merely contains "var " or "meta" is not
                # misread as a keyword
                if line.startswith("var "):
                    break

                # make sure this is a "meta" keyword
                if not line.startswith("meta"):
                    msg = ("Error while parsing META information of network.  Expected META line "
                        "or Variable.  Line: '{}'".format(line))
                    raise SyntaxError(msg)

                tokens = parseMetaInformation(line)
                if len(tokens) != 2:
                    msg = ("Error while parsing META information of network.  Too many tokens. "
                        "Line: '{}'".format(line))
                    raise SyntaxError(msg)

                self.metaInformation[tokens[0].strip()] = tokens[1].strip()
            else:
                # the loop ended without ever reaching a VAR line
                msg = ("Error while parsing network.  No VAR line found.  "
                    "File: '{}'".format(filename))
                raise SyntaxError(msg)

            # line currently holds the first variable declaration
            tokens = parse(line, 0, " ")
            v = self.network.addVariable(tokens[1])

            # read in the remaining variables and their meta information;
            # score lines are ignored in this pass
            for line in f:
                line = line.strip().lower()
                if not line or line.startswith("#"):
                    continue

                if line.startswith("meta"):
                    tokens = parseMetaInformation(line)
                    if len(tokens) < 2:
                        msg = ("Error while parsing META information of variable.  "
                            "Expected 'key=value'.  Line: '{}'".format(line))
                        raise SyntaxError(msg)

                    if "arity" in tokens[0]:
                        v.setArity(int(tokens[1]))
                    elif "values" in tokens[0]:
                        v.setValues(parseVariableValues(tokens[1]))
                    else:
                        v.putMetaInformation(tokens[0].strip(), tokens[1].strip())

                elif line.startswith("var "):
                    tokens = parse(line, 0, " ")
                    v = self.network.addVariable(tokens[1])

        self.variableCount = self.network.getVariableCount()

        # make sure we can represent the variable sets
        self._check_variable_limit()

        # second pass: now that all names and indices are known, read the
        # parent sets
        with open(filename) as f:
            cur_cache = None

            for line in f:
                # skip empty lines, comments and meta keywords
                line = line.strip().lower()
                if not line or line.startswith("#") or line.startswith("meta"):
                    continue

                tokens = parse(line, 0, " ")
                if line.startswith("var "):
                    # a new variable starts: store the finished cache
                    if cur_cache is not None:
                        self.addCache(cur_cache)
                    v = self.network.getByName(tokens[1])
                    cur_cache = {}
                    continue

                # then parse the score for the current variable
                score = float(tokens[0])
                parents = self._parent_mask(tokens[1:])

                # make sure the score is positive
                if score < 0:
                    score = -1 * score

                # NOTE(review): as in the original code, the duplicate check
                # only runs when DEBUG logging is enabled, although it is
                # reported at WARNING level
                if logger.isEnabledFor(logging.DEBUG) and parents in cur_cache:
                    msg = ("Found duplicate entry. Variable: {}, parents: {}".
                        format(v, parents))
                    logger.warning(msg)

                cur_cache[parents] = score

        # add the last cache
        if cur_cache is not None:
            self.addCache(cur_cache)

    def read_jkl(self, filename):
        """ Load the cache from a jkl file.

            The first line holds the number of variables. Each variable then
            has a header line "index number_of_scores" followed by that many
            score lines "score |pa| p1 p2 ... pm", e.g.
            "-106.565548505 3 13 15 11". Variables are registered in the
            network as "Variable0", "Variable1", ...

            Raises:
                SyntaxError: when the file ends before all announced lines
                    were read

                OverflowError: when the file declares more than 64 variables
        """
        with open(filename) as f:

            # the first line is the number of variables
            num_variables = int(f.readline())
            self.variableCount = num_variables

            # make sure we can represent the variable sets
            self._check_variable_limit()

            # add the variables to the network
            for v in range(num_variables):
                self.network.addVariable("Variable" + str(v))

            # and read the actual caches
            for _ in range(num_variables):

                # the next line has the variable index and the number of
                # scores for that variable
                header = f.readline()
                msg = "Expecting to read new variable info. Found: {}".format(header)
                logger.debug(msg)

                fields = header.split()
                if len(fields) < 2:
                    msg = ("Error while reading score cache.  Expected "
                        "'variable score_count'.  Line: '{}'".format(header.strip()))
                    raise SyntaxError(msg)

                var_index = fields[0]
                score_count = int(fields[1])

                # The result is unused, as in the original code.
                # NOTE(review): presumably this lookup checks that the header
                # names a known variable -- confirm against BayesianNetwork.
                self.network.getByName("Variable" + var_index)

                # NOTE(review): caches are appended in file order, so this
                # assumes the file lists the variables as 0, 1, 2, ...
                cur_cache = {}
                for _ in range(score_count):
                    fields = f.readline().split()
                    if not fields:
                        msg = ("Error while reading score cache.  Expected {} scores "
                            "for variable '{}'".format(score_count, var_index))
                        raise SyntaxError(msg)

                    # make sure the score is positive
                    score = float(fields[0])
                    if score < 0:
                        score = -1 * score

                    # fields[1] is the parent count; the parents follow it
                    parents = self._parent_mask(
                        "Variable" + p for p in fields[2:])

                    cur_cache[parents] = score

                self.addCache(cur_cache)

    def write(self, filename):
        """ Write the cache to the specified file in pss format: the
            cache-level META lines, a blank line, and then for each variable
            its VAR line, its META lines, one line per score
            ("score parent_name ...") and a blank line.
        """
        with open(filename, "w") as f:

            # first, all of the meta information
            for key, value in self.metaInformation.items():
                f.write("META {}={}\n".format(key, value))

            # and a blank line
            f.write("\n")

            # then each variable
            for x in range(self.variableCount):
                v = self.network.get(x)
                f.write("VAR {}\n".format(v.getName()))
                f.write("META arity={}\n".format(v.getArity()))

                values = "".join(" {}".format(val) for val in v.getValues())
                f.write("META values={}\n".format(values))

                for key in v.getMetaInformationKeys():
                    f.write("META {}={}\n".format(key, v.getMetaInformation(key)))

                # now all of the scores, each followed by its parent names
                for parents, score in zip(self.parents[x], self.scores[x]):
                    names = "".join(
                        " {}".format(self.network.get(p).getName())
                        for p in get_parents_list(parents, self.variableCount))
                    f.write("{}{}\n".format(score, names))

                # and a blank line
                f.write("\n")

    def write_jkl(self, filename):
        """ This function writes the score cache to the specified file in jkl format.
        """
        with open(filename, 'w') as f:
            # first, the number of variables
            f.write("{}\n".format(self.variableCount))

            # now, the scores for each variable
            for x, (parents_x, scores_x) in enumerate(zip(self.parents, self.scores)):
                # header: variable index and number of parent sets
                f.write("{} {}\n".format(x, len(parents_x)))

                for mask, score in zip(parents_x, scores_x):
                    _write_jkl_score(score, mask, self.variableCount, f)

    def getBestScore(self, x, parents):
        """ Return the score of the first stored parent set of variable x
            that is a subset of parents.

            Because addCache stores the entries in ascending score order,
            this is the smallest stored score among the admissible sets.

            Args:
                x (int): index of the variable

                parents (np.uint64 or int): bit mask of the allowed parents

            Raises:
                KeyError: when no stored parent set is a subset of parents
        """
        candidates = np.asarray(self.parents[x], dtype=np.uint64)
        allowed = np.uint64(parents)

        # a stored set is admissible when masking it with the allowed
        # parents leaves it unchanged
        hits = np.flatnonzero((candidates & allowed) == candidates)
        if hits.size:
            return self.scores[x][hits[0]]

        msg = "Parent set not in score cache.  Variable: '{}', Parents: '{}'".format(x, parents)
        raise KeyError(msg)
